[TritonGPU] Split MemDescSubview into MemDescIndex and MemDescSubslice#7622
Conversation
The first one will be used just for pipelining and it's equivalent to `x[i]`, the second one takes a full slice of constant shape `x[:i1, :i2]`, for example.
ThomasRaoux
left a comment
There was a problem hiding this comment.
Awesome great cleanup! Few minor comments
| %cst_0 = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> | ||
| // expected-remark @below {{%2 -> %0}} | ||
| %0 = ttg.memdesc_subview %cst[%idx, %idx] : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> | ||
| %0 = ttg.memdesc_subslice %cst {offsets=array<i32: 0, 0>} : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> |
There was a problem hiding this comment.
nit: we can probably have a better printing like %cts[0, 0] but it can be done as a follow up
There was a problem hiding this comment.
vibecoded it with an agent in 5 min in c89d9c6
I didn't even know about that custom API...
| bool is1D = | ||
| srcTy.getRank() == 1 && dstTy.getRank() == 1 && dstTy.getDimSize(0) == 1; |
There was a problem hiding this comment.
when do we need the 1d case?
There was a problem hiding this comment.
when we pipeline barriers and things like that.
There was a problem hiding this comment.
In Gluon we do Nx1xi64 to get around having to support this in the APIs. Changing that in the compiler however would mean needing to update a LOT of tests...
There was a problem hiding this comment.
after this PR I'm not scared of having to change a lot of tests (in reality it was horrible)
on a different note, this would be a lovely task for an agent
There was a problem hiding this comment.
yeah would be nice to clean up
#7622 introduced `ttg.memdesc_index` which applies a constant offset to the base pointer of the smem object. For padded layouts we need to add padding based on the offset, similar to what #7404 did for the old subview operation. I also adjusted the lit test to check we actually generate padding from the ttg.memdesc_index. The previous version did not fail because it matched the lowering of the `ttg.local_load/store` as well.
…7696) triton-lang#7622 introduced `ttg.memdesc_index` which applies a constant offset to the base pointer of the smem object. For padded layouts we need to add padding based on the offset, similar to what triton-lang#7404 did for the old subview operation. I also adjusted the lit test to check we actually generate padding from the ttg.memdesc_index. The previous version did not fail because it matched the lowering of the `ttg.local_load/store` as well.
…7696) triton-lang#7622 introduced `ttg.memdesc_index` which applies a constant offset to the base pointer of the smem object. For padded layouts we need to add padding based on the offset, similar to what triton-lang#7404 did for the old subview operation. I also adjusted the lit test to check we actually generate padding from the ttg.memdesc_index. The previous version did not fail because it matched the lowering of the `ttg.local_load/store` as well.
third_party/tlx/run_all.sh [TLX-3.5] Fix memdesc_subview refactoring from triton-lang#7622 pytest python/test/unit/language/test_tlx.py::test_load_store_smem_with_tl_load pytest python/test/unit/language/test_tlx.py::test_local_store pytest python/test/unit/language/test_tlx.py::test_local_load TODO. fix TLX layout propagation LITs using memdesc_subview [TLX-3.5] Fix barrier ops caused by 1D tensor handling by memdesc_index python/test/unit/language/test_tlx.py::test_wait_arrive_non_ws The root cause is memdesc_index fail its `verify()` for 1D tensor case. It's caused by a bug in merging conflicts. More related discussions: https://github.com/triton-lang/triton/pull/7622/files#r2227788997 [TLX-3.5] Fix all UTs python/test/unit/language/test_tlx.py::test_async_dot
The first one will be used just for pipelining and it's equivalent to
x[i], the second one takes a full slice of constant shapex[:i1, :i2],for example.